"""recur_mlp.py
Recurrent multi-layer perceptron PyTorch model class.
"""
import torch
import torch.nn as nn


class FullyConnectedBlock(nn.Module):
    """Square linear layer followed by optional batch norm and a ReLU.

    Args:
        width: Number of input and output features of the linear layer.
        bn: If True, ``nn.BatchNorm1d(width)`` is applied between the linear
            layer and the ReLU, and the linear layer is created without a bias.
    """

    def __init__(self, width, bn=False):
        super().__init__()
        # The bias is dropped under batch norm, which brings its own shift term.
        self.linear = nn.Linear(width, width, bias=not bn)
        self.bn = bn
        if bn:
            self.bn_layer = nn.BatchNorm1d(width)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(bn(linear(x)))``, skipping ``bn`` when it is disabled."""
        out = self.linear(x)
        if self.bn:
            # Normalize the linear layer's output; normalizing the raw input
            # instead would throw the linear transform away.
            out = self.bn_layer(out)
        return self.relu(out)


class RecurMLP(nn.Module):
    """Multi-layer perceptron whose hidden block is applied recurrently.

    The input is flattened and passed through ``linear_first`` (plus batch norm
    when enabled) and a ReLU. The same hidden block (shared weights) is then
    applied ``depth - 2`` times, and ``linear_last`` maps the hidden state to
    an output after every application.

    Args:
        block: Module class for the hidden block; constructed as
            ``block(width, bn=bn)``.
        num_inputs: Number of features per sample after flattening.
        num_outputs: Number of output features of ``linear_last``.
        width: Hidden feature size.
        depth: Nominal depth; the hidden block runs ``depth - 2`` times, so
            this must be at least 3.
        bn: If True, batch norm follows the first linear layer (which is then
            created without a bias) and ``bn`` is forwarded to the hidden block.

    Attributes:
        thoughts: Set by ``forward``; tensor of shape
            ``(iters, batch, num_outputs)`` holding the output computed after
            each application of the hidden block.

    Raises:
        ValueError: If ``depth`` is smaller than 3.
    """

    def __init__(self, block=FullyConnectedBlock, num_inputs=32*32*3, num_outputs=1, width=1000,
                 depth=5, bn=False):
        super().__init__()
        if depth < 3:
            # Without at least one iteration forward() has nothing to return
            # and would fail later with an opaque indexing/size error.
            raise ValueError(
                f"depth must be at least 3 (got {depth}): "
                "the hidden block is applied depth - 2 times"
            )
        self.block = block
        self.depth = depth
        self.num_outputs = num_outputs
        self.iters = depth - 2
        self.bn = bn
        self.linear_first = nn.Linear(num_inputs, width, bias=not bn)
        if bn:
            self.bn_first = nn.BatchNorm1d(width)
        self.relu = nn.ReLU()
        self.layers = self._make_layer(block, width, 1, bn)
        if bn:
            # Not used by forward(); kept so the set of registered modules
            # (and therefore the state_dict keys) stays unchanged.
            self.bn_second = nn.BatchNorm1d(width)
        self.linear_last = nn.Linear(width, num_outputs)

    def _make_layer(self, block, width, depth, bn):
        """Return an ``nn.Sequential`` of ``depth`` instances of ``block(width, bn=bn)``."""
        return nn.Sequential(*(block(width, bn=bn) for _ in range(depth)))

    def forward(self, x):
        """Run the network and record the output of every iteration.

        Args:
            x: Input batch; all dimensions after the first are flattened.

        Returns:
            The output after the last iteration, i.e. ``self.thoughts[-1]``,
            of shape ``(batch, num_outputs)``.
        """
        # Allocate the buffer on the input's device directly instead of
        # creating it on the default device and copying it over.
        self.thoughts = torch.zeros((self.iters, x.shape[0], self.num_outputs), device=x.device)

        out = x.view(x.size(0), -1)
        out = self.linear_first(out)
        if self.bn:
            out = self.bn_first(out)
        out = self.relu(out)

        for i in range(self.iters):
            # The same block (shared weights) is reused on every iteration.
            out = self.layers(out)
            self.thoughts[i] = self.linear_last(out)
        return self.thoughts[-1]


def recur_mlp_100_3(num_outputs=10):
    """Build a RecurMLP with width 100 and depth 3; other options keep their defaults."""
    fixed = dict(width=100, depth=3)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_100_4(num_outputs=10):
    """Build a RecurMLP with width 100 and depth 4; other options keep their defaults."""
    fixed = dict(width=100, depth=4)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_100_5(num_outputs=10):
    """Build a RecurMLP with width 100 and depth 5; other options keep their defaults."""
    fixed = dict(width=100, depth=5)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_100_6(num_outputs=10):
    """Build a RecurMLP with width 100 and depth 6; other options keep their defaults."""
    fixed = dict(width=100, depth=6)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_100_7(num_outputs=10):
    """Build a RecurMLP with width 100 and depth 7; other options keep their defaults."""
    fixed = dict(width=100, depth=7)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_100_8(num_outputs=10):
    """Build a RecurMLP with width 100 and depth 8; other options keep their defaults."""
    fixed = dict(width=100, depth=8)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_100_9(num_outputs=10):
    """Build a RecurMLP with width 100 and depth 9; other options keep their defaults."""
    fixed = dict(width=100, depth=9)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_100_10(num_outputs=10):
    """Build a RecurMLP with width 100 and depth 10; other options keep their defaults."""
    fixed = dict(width=100, depth=10)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_200_3(num_outputs=10):
    """Build a RecurMLP with width 200 and depth 3; other options keep their defaults."""
    fixed = dict(width=200, depth=3)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_200_4(num_outputs=10):
    """Build a RecurMLP with width 200 and depth 4; other options keep their defaults."""
    fixed = dict(width=200, depth=4)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_200_5(num_outputs=10):
    """Build a RecurMLP with width 200 and depth 5; other options keep their defaults."""
    fixed = dict(width=200, depth=5)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_200_6(num_outputs=10):
    """Build a RecurMLP with width 200 and depth 6; other options keep their defaults."""
    fixed = dict(width=200, depth=6)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_200_7(num_outputs=10):
    """Build a RecurMLP with width 200 and depth 7; other options keep their defaults."""
    fixed = dict(width=200, depth=7)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_200_8(num_outputs=10):
    """Build a RecurMLP with width 200 and depth 8; other options keep their defaults."""
    fixed = dict(width=200, depth=8)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_200_9(num_outputs=10):
    """Build a RecurMLP with width 200 and depth 9; other options keep their defaults."""
    fixed = dict(width=200, depth=9)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_200_10(num_outputs=10):
    """Build a RecurMLP with width 200 and depth 10; other options keep their defaults."""
    fixed = dict(width=200, depth=10)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_250_3(num_outputs=10):
    """Build a RecurMLP with width 250 and depth 3; other options keep their defaults."""
    fixed = dict(width=250, depth=3)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_250_4(num_outputs=10):
    """Build a RecurMLP with width 250 and depth 4; other options keep their defaults."""
    fixed = dict(width=250, depth=4)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_250_5(num_outputs=10):
    """Build a RecurMLP with width 250 and depth 5; other options keep their defaults."""
    fixed = dict(width=250, depth=5)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_250_6(num_outputs=10):
    """Build a RecurMLP with width 250 and depth 6; other options keep their defaults."""
    fixed = dict(width=250, depth=6)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_250_7(num_outputs=10):
    """Build a RecurMLP with width 250 and depth 7; other options keep their defaults."""
    fixed = dict(width=250, depth=7)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_250_8(num_outputs=10):
    """Build a RecurMLP with width 250 and depth 8; other options keep their defaults."""
    fixed = dict(width=250, depth=8)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_250_9(num_outputs=10):
    """Build a RecurMLP with width 250 and depth 9; other options keep their defaults."""
    fixed = dict(width=250, depth=9)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_250_10(num_outputs=10):
    """Build a RecurMLP with width 250 and depth 10; other options keep their defaults."""
    fixed = dict(width=250, depth=10)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_500_3(num_outputs=10):
    """Build a RecurMLP with width 500 and depth 3; other options keep their defaults."""
    fixed = dict(width=500, depth=3)
    return RecurMLP(num_outputs=num_outputs, **fixed)

def recur_mlp_500_4(num_outputs=10):
    """Build a RecurMLP with width 500 and depth 4; other options keep their defaults."""
    fixed = dict(width=500, depth=4)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_500_5(num_outputs=10):
    """Build a RecurMLP with width 500 and depth 5; other options keep their defaults."""
    fixed = dict(width=500, depth=5)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_500_6(num_outputs=10):
    """Build a RecurMLP with width 500 and depth 6; other options keep their defaults."""
    fixed = dict(width=500, depth=6)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_500_7(num_outputs=10):
    """Build a RecurMLP with width 500 and depth 7; other options keep their defaults."""
    fixed = dict(width=500, depth=7)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_1000_3(num_outputs=10):
    """Build a RecurMLP with width 1000 and depth 3; other options keep their defaults."""
    fixed = dict(width=1000, depth=3)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_1000_4(num_outputs=10):
    """Build a RecurMLP with width 1000 and depth 4; other options keep their defaults."""
    fixed = dict(width=1000, depth=4)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_1000_5(num_outputs=10):
    """Build a RecurMLP with width 1000 and depth 5; other options keep their defaults."""
    fixed = dict(width=1000, depth=5)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_1000_6(num_outputs=10):
    """Build a RecurMLP with width 1000 and depth 6; other options keep their defaults."""
    fixed = dict(width=1000, depth=6)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_1000_7(num_outputs=10):
    """Build a RecurMLP with width 1000 and depth 7; other options keep their defaults."""
    fixed = dict(width=1000, depth=7)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_250_4_bn(num_outputs=10):
    """Build a RecurMLP with width 250, depth 4 and batch norm enabled."""
    fixed = dict(width=250, depth=4, bn=True)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_250_5_bn(num_outputs=10):
    """Build a RecurMLP with width 250, depth 5 and batch norm enabled."""
    fixed = dict(width=250, depth=5, bn=True)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_250_6_bn(num_outputs=10):
    """Build a RecurMLP with width 250, depth 6 and batch norm enabled."""
    fixed = dict(width=250, depth=6, bn=True)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_250_7_bn(num_outputs=10):
    """Build a RecurMLP with width 250, depth 7 and batch norm enabled."""
    fixed = dict(width=250, depth=7, bn=True)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_500_4_bn(num_outputs=10):
    """Build a RecurMLP with width 500, depth 4 and batch norm enabled."""
    fixed = dict(width=500, depth=4, bn=True)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_500_5_bn(num_outputs=10):
    """Build a RecurMLP with width 500, depth 5 and batch norm enabled."""
    fixed = dict(width=500, depth=5, bn=True)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_500_6_bn(num_outputs=10):
    """Build a RecurMLP with width 500, depth 6 and batch norm enabled."""
    fixed = dict(width=500, depth=6, bn=True)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_500_7_bn(num_outputs=10):
    """Build a RecurMLP with width 500, depth 7 and batch norm enabled."""
    fixed = dict(width=500, depth=7, bn=True)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_1000_4_bn(num_outputs=10):
    """Build a RecurMLP with width 1000, depth 4 and batch norm enabled."""
    fixed = dict(width=1000, depth=4, bn=True)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_1000_5_bn(num_outputs=10):
    """Build a RecurMLP with width 1000, depth 5 and batch norm enabled."""
    fixed = dict(width=1000, depth=5, bn=True)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_1000_6_bn(num_outputs=10):
    """Build a RecurMLP with width 1000, depth 6 and batch norm enabled."""
    fixed = dict(width=1000, depth=6, bn=True)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_1000_7_bn(num_outputs=10):
    """Build a RecurMLP with width 1000, depth 7 and batch norm enabled."""
    fixed = dict(width=1000, depth=7, bn=True)
    return RecurMLP(num_outputs=num_outputs, **fixed)



def recur_mlp_20_3_mnist(num_outputs=10):
    """Build a RecurMLP with width 20 and depth 3 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=20, depth=3)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_20_4_mnist(num_outputs=10):
    """Build a RecurMLP with width 20 and depth 4 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=20, depth=4)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_20_5_mnist(num_outputs=10):
    """Build a RecurMLP with width 20 and depth 5 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=20, depth=5)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_20_6_mnist(num_outputs=10):
    """Build a RecurMLP with width 20 and depth 6 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=20, depth=6)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_20_7_mnist(num_outputs=10):
    """Build a RecurMLP with width 20 and depth 7 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=20, depth=7)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_20_8_mnist(num_outputs=10):
    """Build a RecurMLP with width 20 and depth 8 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=20, depth=8)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_20_9_mnist(num_outputs=10):
    """Build a RecurMLP with width 20 and depth 9 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=20, depth=9)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_20_10_mnist(num_outputs=10):
    """Build a RecurMLP with width 20 and depth 10 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=20, depth=10)
    return RecurMLP(num_outputs=num_outputs, **fixed)

def recur_mlp_100_4_mnist(num_outputs=10):
    """Build a RecurMLP with width 100 and depth 4 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=100, depth=4)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_100_5_mnist(num_outputs=10):
    """Build a RecurMLP with width 100 and depth 5 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=100, depth=5)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_100_6_mnist(num_outputs=10):
    """Build a RecurMLP with width 100 and depth 6 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=100, depth=6)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_100_7_mnist(num_outputs=10):
    """Build a RecurMLP with width 100 and depth 7 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=100, depth=7)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_100_8_mnist(num_outputs=10):
    """Build a RecurMLP with width 100 and depth 8 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=100, depth=8)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_100_9_mnist(num_outputs=10):
    """Build a RecurMLP with width 100 and depth 9 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=100, depth=9)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_100_10_mnist(num_outputs=10):
    """Build a RecurMLP with width 100 and depth 10 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=100, depth=10)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_200_3_mnist(num_outputs=10):
    """Build a RecurMLP with width 200 and depth 3 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=200, depth=3)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_200_4_mnist(num_outputs=10):
    """Build a RecurMLP with width 200 and depth 4 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=200, depth=4)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_200_5_mnist(num_outputs=10):
    """Build a RecurMLP with width 200 and depth 5 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=200, depth=5)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_200_6_mnist(num_outputs=10):
    """Build a RecurMLP with width 200 and depth 6 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=200, depth=6)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_200_7_mnist(num_outputs=10):
    """Build a RecurMLP with width 200 and depth 7 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=200, depth=7)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_200_8_mnist(num_outputs=10):
    """Build a RecurMLP with width 200 and depth 8 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=200, depth=8)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_200_9_mnist(num_outputs=10):
    """Build a RecurMLP with width 200 and depth 9 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=200, depth=9)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_200_10_mnist(num_outputs=10):
    """Build a RecurMLP with width 200 and depth 10 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=200, depth=10)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_250_3_mnist(num_outputs=10):
    """Build a RecurMLP with width 250 and depth 3 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=250, depth=3)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_250_4_mnist(num_outputs=10):
    """Build a RecurMLP with width 250 and depth 4 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=250, depth=4)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_250_5_mnist(num_outputs=10):
    """Build a RecurMLP with width 250 and depth 5 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=250, depth=5)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_250_6_mnist(num_outputs=10):
    """Build a RecurMLP with width 250 and depth 6 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=250, depth=6)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_250_7_mnist(num_outputs=10):
    """Build a RecurMLP with width 250 and depth 7 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=250, depth=7)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_250_8_mnist(num_outputs=10):
    """Build a RecurMLP with width 250 and depth 8 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=250, depth=8)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_250_9_mnist(num_outputs=10):
    """Build a RecurMLP with width 250 and depth 9 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=250, depth=9)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_250_10_mnist(num_outputs=10):
    """Build a RecurMLP with width 250 and depth 10 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=250, depth=10)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_500_4_mnist(num_outputs=10):
    """Build a RecurMLP with width 500 and depth 4 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=500, depth=4)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_500_5_mnist(num_outputs=10):
    """Build a RecurMLP with width 500 and depth 5 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=500, depth=5)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_500_6_mnist(num_outputs=10):
    """Build a RecurMLP with width 500 and depth 6 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=500, depth=6)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_500_7_mnist(num_outputs=10):
    """Build a RecurMLP with width 500 and depth 7 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=500, depth=7)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_1000_3_mnist(num_outputs=10):
    """Build a RecurMLP with width 1000 and depth 3 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=1000, depth=3)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_1000_4_mnist(num_outputs=10):
    """Build a RecurMLP with width 1000 and depth 4 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=1000, depth=4)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_1000_5_mnist(num_outputs=10):
    """Build a RecurMLP with width 1000 and depth 5 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=1000, depth=5)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_1000_6_mnist(num_outputs=10):
    """Build a RecurMLP with width 1000 and depth 6 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=1000, depth=6)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_1000_7_mnist(num_outputs=10):
    """Build a RecurMLP with width 1000 and depth 7 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=1000, depth=7)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_250_4_bn_mnist(num_outputs=10):
    """Build a batch-norm RecurMLP with width 250 and depth 4 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=250, depth=4, bn=True)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_250_5_bn_mnist(num_outputs=10):
    """Build a batch-norm RecurMLP with width 250 and depth 5 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=250, depth=5, bn=True)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_250_6_bn_mnist(num_outputs=10):
    """Build a batch-norm RecurMLP with width 250 and depth 6 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=250, depth=6, bn=True)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_250_7_bn_mnist(num_outputs=10):
    """Build a batch-norm RecurMLP with width 250 and depth 7 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=250, depth=7, bn=True)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_500_4_bn_mnist(num_outputs=10):
    """Build a batch-norm RecurMLP with width 500 and depth 4 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=500, depth=4, bn=True)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_500_5_bn_mnist(num_outputs=10):
    """Build a batch-norm RecurMLP with width 500 and depth 5 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=500, depth=5, bn=True)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_500_6_bn_mnist(num_outputs=10):
    """Build a batch-norm RecurMLP with width 500 and depth 6 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=500, depth=6, bn=True)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_500_7_bn_mnist(num_outputs=10):
    """Build a batch-norm RecurMLP with width 500 and depth 7 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=500, depth=7, bn=True)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_1000_4_bn_mnist(num_outputs=10):
    """Build a batch-norm RecurMLP with width 1000 and depth 4 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=1000, depth=4, bn=True)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_1000_5_bn_mnist(num_outputs=10):
    """Build a batch-norm RecurMLP with width 1000 and depth 5 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=1000, depth=5, bn=True)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_1000_6_bn_mnist(num_outputs=10):
    """Build a batch-norm RecurMLP with width 1000 and depth 6 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=1000, depth=6, bn=True)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_1000_7_bn_mnist(num_outputs=10):
    """Build a batch-norm RecurMLP with width 1000 and depth 7 for 1*28*28 flattened inputs."""
    fixed = dict(num_inputs=1 * 28 * 28, width=1000, depth=7, bn=True)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_500_3_emnist(num_outputs=47):
    """Build a RecurMLP with width 500 and depth 3 for 1*28*28 inputs (47 outputs by default)."""
    fixed = dict(num_inputs=1 * 28 * 28, width=500, depth=3)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_500_4_emnist(num_outputs=47):
    """Build a RecurMLP with width 500 and depth 4 for 1*28*28 inputs (47 outputs by default)."""
    fixed = dict(num_inputs=1 * 28 * 28, width=500, depth=4)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_500_5_emnist(num_outputs=47):
    """Build a RecurMLP with width 500 and depth 5 for 1*28*28 inputs (47 outputs by default)."""
    fixed = dict(num_inputs=1 * 28 * 28, width=500, depth=5)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_500_6_emnist(num_outputs=47):
    """Build a RecurMLP with width 500 and depth 6 for 1*28*28 inputs (47 outputs by default)."""
    fixed = dict(num_inputs=1 * 28 * 28, width=500, depth=6)
    return RecurMLP(num_outputs=num_outputs, **fixed)


def recur_mlp_500_7_emnist(num_outputs=47):
    """Build a RecurMLP with width 500 and depth 7 for 1*28*28 inputs (47 outputs by default)."""
    fixed = dict(num_inputs=1 * 28 * 28, width=500, depth=7)
    return RecurMLP(num_outputs=num_outputs, **fixed)